"""
Author: xuzhougeng
Affiliation: SIPPE
Aim: A Snakemake workflow to analyze ATAC-seq data
Run: snakemake -s Snakefile
reference
- https://docs.google.com/document/d/1f0Cm4vRyDQDu0bMehHD7P7KOMxTOP-HiNoIvL1VcBt8/edit
"""

# Globals-----------------------------------------------------------------------
import sys
from os.path import join, basename, isfile
import pandas as pd
import numpy as np
# Do not truncate the column list when a DataFrame is printed.
pd.set_option("display.max_columns", None)

## Reference and run settings, all read from the Snakemake `config` mapping.
## Uncomment the directive below to load them from a fixed file.
#configfile: "config.yaml"
FASTA = config["FASTA"]
# Index prefix: reference/<FASTA file name up to its first dot>
FASTA_BASE = basename(FASTA).split(".")[0]
INDEX = join("reference", FASTA_BASE)
GFF = config["GFF"]
ADAPTER = config["ADAPTER"]
GENOMESIZE = config["GENOMESIZE"]
SAMPLES = config["SAMPLES"]

MQ = config["MAP_QUALITY"]
REMOVE_PLASTOME = config["REMOVE_PLASTOME"]

# A string value of REMOVE_PLASTOME must name an existing file;
# otherwise report it on stderr and stop with exit status 1.
if isinstance(REMOVE_PLASTOME, str):
    if not isfile(REMOVE_PLASTOME):
        print("{} is not existed".format(REMOVE_PLASTOME), file=sys.stderr)
        sys.exit(1)
    PLASTOME = REMOVE_PLASTOME
        

# Working directories (relative paths); everything except the QC report
# directory lives under one common root.
_ANALYSIS_ROOT = "analysis"

RAW_DIR = join(_ANALYSIS_ROOT, "00-raw-data")
CLEAN_DIR = join(_ANALYSIS_ROOT, "01-clean-data")
ALIGN_DIR = join(_ANALYSIS_ROOT, "02-read-alignment")
PEAK_DIR = join(_ANALYSIS_ROOT, "03-peak-calling")
BW_DIR = join(_ANALYSIS_ROOT, "04-bigwig")
VIS_DIR = join(_ANALYSIS_ROOT, "05-visulization")
TMPDIR = join(_ANALYSIS_ROOT, "tmp")
LOGDIR = join(_ANALYSIS_ROOT, "log")
# join() with a single component returns it unchanged
QC_DIR = "report"

#shell("mkdir -p {CLEAN_DIR} {ALIGN_DIR} {PEAK_DIR} {BW_DIR} {VIS_DIR} {TMPDIR} {LOGDIR} {QC_DIR} ")

## Sample sheet: a comma-separated file whose first column is used as the
## index; the index values are the sample names driving the workflow.
## (get_files_from_sampleTable also reads the 'fq1' and 'fq2' columns.)
sampleTable = pd.read_csv(SAMPLES, sep=",", index_col=0)
samples = sampleTable.index.to_list()


## SAMPLES = [ sample.strip() for sample in open(config['SAMPLES']) ] 

# Target files requested by rule all
# -- one bigWig track and one narrowPeak file per sample
ALL_BWS = expand(join(BW_DIR, "{sample}.bw"), sample=samples)
ALL_PEAKS = expand(join(PEAK_DIR, "{sample}_peaks.narrowPeak"), sample=samples)

# -- plots written by the correlation_scatter / correlation_heatmap rules
ALL_CORR_PDF = [
    join(VIS_DIR, FASTA_BASE, pdf_name)
    for pdf_name in (
        "correlation_spearman_bwscore_scatterplot.pdf",
        "correlation_spearman_bwscore_heatmapplot.pdf",
    )
]

# -- profile plots written by the matrix_vis rule
ALL_REF_PDF = [
    join(VIS_DIR, FASTA_BASE, pdf_name)
    for pdf_name in (
        "scale_region.pdf",
        "scale_region_persample.pdf",
        "reference_point_region.pdf",
        "reference_point_region_persample.pdf",
    )
]

# utility functions
def get_files_from_sampleTable(wildcards):
    """Return the FASTQ pair recorded for the sample in ``wildcards.sample``.

    The sample's row is looked up in the module-level ``sampleTable`` and the
    values of its ``fq1`` and ``fq2`` columns are returned, in that order, as
    a list.
    """
    fastq_pair = sampleTable.loc[wildcards.sample, ['fq1', 'fq2']]
    return fastq_pair.to_list()


# Rules------------------------------------------------------------------------
# Default target. Bigwigs and peaks are always requested; the correlation
# plots are added when there is more than one sample, and the profile plots
# when a GFF is configured.
ALL_TARGETS = [*ALL_BWS, *ALL_PEAKS]
if len(samples) > 1:
    ALL_TARGETS += ALL_CORR_PDF
if GFF:
    ALL_TARGETS += ALL_REF_PDF

rule all:
    input: ALL_TARGETS



# Run bowtie2-build on the reference FASTA using INDEX as the index prefix.
# The only declared output is a touched flag file, which later rules use as
# their dependency on the finished index.
rule build_index:
    input: FASTA
    output: touch( INDEX + "_build_index.done" )
    params:
        bt2_index_base = INDEX
    threads: 10
    shell:"""
    bowtie2-build --threads {threads} {input} {params.bt2_index_base} 
    """

# Run fastp on one sample's read pair (looked up in the sample table), passing
# the configured adapter. Cleaned reads are temporary; the JSON/HTML reports
# are written under QC_DIR using the sample name as prefix.
# NOTE: params.length and params.quality are declared but the command below
# does not pass them to fastp.
rule data_clean:
    input: get_files_from_sampleTable
    output:
        r1 = temp(join(CLEAN_DIR, "{sample}_R1.fq.gz")),
        r2 = temp(join(CLEAN_DIR, "{sample}_R2.fq.gz"))
    params:
        adapter = ADAPTER,
        length  = 50,
        quality = 20,
        prefix  = join(QC_DIR, "{sample}")
    log: join(LOGDIR, "{sample}_fastp.log")
    threads: 4
    message: "----- processing {input[0]} and {input[1]} with fastp ------"
    shell: """
	    fastp --thread {threads} \
		-a {params.adapter} \
		-i {input[0]} -I {input[1]} \
		-o {output.r1} -O {output.r2} \
		-j {params.prefix}.json \
		-h {params.prefix}.html \
		2> {log}
    """

# Align the cleaned read pair with bowtie2, pipe the alignments into
# `samtools sort` to get a BAM file, then index it. bowtie2's stderr goes to
# the log; the sorted BAM is temporary.
rule read_align:
    input:
        r1 = join(CLEAN_DIR, "{sample}_R1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_R2.fq.gz"),
        # flag file from build_index: guarantees the index was built first
        bt_index = INDEX + "_build_index.done"
    output:
        temp(join(ALIGN_DIR, "{sample}_sorted.bam")),
    params:
        index  = INDEX,
        #rg     = r"@RG\tID:{sample}\tSM:{sample}\tPL:ILLUMINA",
    log:
        log1=join(LOGDIR, "{sample}_align_warning.log"),
    threads: 40
    message: "align {input.r1} and {input.r2} to {params.index} with {threads} threads"
    shell:"""
        bowtie2 -p {threads} -x {params.index} \
            -1 {input.r1} -2 {input.r2} 2> {log.log1} | \
        samtools sort -@ {threads} -O bam -o {output} - && \
        samtools index -@ {threads} {output}  
    """

# Mark duplicate reads in the sorted BAM with sambamba; its stderr goes to the
# log. The marked BAM is the input of the filtration and bam_basic_stats rules.
rule markdup:
    input:
        join(ALIGN_DIR, "{sample}_sorted.bam"),
    output:
        join(ALIGN_DIR, "{sample}_markdup.bam"),
    log:
        join(LOGDIR, "{sample}_sambamba.log"),
    threads: 4
    message: "----- mark duplication in {input} with sambamba ------"
    shell:"""
	    sambamba markdup -t {threads} {input} {output} 2> {log}
	"""

# Filter duplicates, multi-mappers and low-quality reads with samtools.
# `-F 1804` excludes reads carrying any of the SAM flags 4 (unmapped),
# 8 (mate unmapped), 256 (secondary), 512 (QC fail) or 1024 (duplicate);
# `-q` excludes reads with a mapping quality below MAP_QUALITY.
#
# REMOVE_PLASTOME selects an optional region restriction (`samtools view -L`):
#   * str  -> path of a BED file listing the regions to keep
#   * int  -> minimum contig length; a BED of all longer contigs is generated
#             from the bowtie2 index by detect_plastome
#   * anything else (e.g. null, true/false) -> no region restriction
# bool is a subclass of int in Python, so booleans are excluded explicitly
# from the integer branch; otherwise `REMOVE_PLASTOME: false` would be used as
# a length threshold.
if isinstance(REMOVE_PLASTOME, str):
    rule filtration:
        input:
            join(ALIGN_DIR, "{sample}_markdup.bam"),
        params:
            quality = MQ,
            bed = REMOVE_PLASTOME
        output:
            join(ALIGN_DIR, "{sample}_flt.bam")
        threads: 20
        message: "----- filter duplcation, multi-mappers, low_quality reads with samtools -----"
        shell:"""
            samtools view -@ {threads} -L {params.bed} -bF 1804 -q {params.quality} {input} -o {output} && \
            samtools index -@ {threads} {output} 
        """
elif isinstance(REMOVE_PLASTOME, int) and not isinstance(REMOVE_PLASTOME, bool):
    # Write a BED (name, 0, length) of every indexed sequence longer than
    # REMOVE_PLASTOME. `bowtie2-inspect -s` prints tab-separated
    # "Sequence-N <name> <length>" lines, so fields are split on tabs and only
    # the first whitespace-delimited word of the name is kept; a name that
    # contains spaces therefore cannot shift the length column.
    rule detect_plastome:
        input: INDEX + "_build_index.done" 
        params:
            index  = INDEX,
            len = str(REMOVE_PLASTOME)
        output: 
            bed =  INDEX + "no_plastome.bed" 
        shell: """
        bowtie2-inspect -s {params.index} | awk -F '\t' -v OFS='\t' '$1 ~ /^Sequence/ && $3 > {params.len} {{ split($2, id, " "); print id[1], 0, $3 }}' > {output}
        """
    rule filtration:
        input:
            join(ALIGN_DIR, "{sample}_markdup.bam"),
            rules.detect_plastome.output.bed
        params:
            quality = MQ
        output:
            join(ALIGN_DIR, "{sample}_flt.bam")
        threads: 20
        message: "----- filter duplcation, multi-mappers, low_quality reads with samtools -----"
        shell:"""
            samtools view -@ {threads} -L {input[1]} -bF 1804 -q {params.quality} {input[0]} -o {output} && \
            samtools index -@ {threads} {output} 
        """
else:
    rule filtration:
        input:
            join(ALIGN_DIR, "{sample}_markdup.bam"),
        params:
            quality = MQ
        output:
            join(ALIGN_DIR, "{sample}_flt.bam")
        threads: 20
        message: "----- filter duplcation, multi-mappers, low_quality reads with samtools -----"
        shell:"""
            samtools view -@ {threads} -bF 1804 -q {params.quality} {input} -o {output}  && \
            samtools index -@ {threads} {output} 
        """

# quality control
# Run the external `bam_basic_stats` command on the duplicate-marked BAM and
# write its result into the report directory. The command receives the input
# BAM, the output path and REMOVE_PLASTOME rendered as a string, in that order.
rule bam_basic_stats:
    input: rules.markdup.output
    params:
        len = str(REMOVE_PLASTOME)
    output: 
        join(QC_DIR, "{sample}_markdup_basic_stats.txt")
    shell:"""
        bam_basic_stats {input} {output} {params.len}
    """

# Run the external `bam_frag_stats` command on the filtered BAM and write its
# result into the report directory.
rule bam_frag_size_stats:
    input:
        join(ALIGN_DIR, "{sample}_flt.bam")
    output:
        join(QC_DIR, "{sample}_frag_size.txt")
    shell:"""
        bam_frag_stats {input} {output}
    """


# Call peaks on the filtered BAM with MACS2 in paired-end BAM mode (-f BAMPE).
# The output path is not passed to the command: the rule relies on MACS2
# writing <prefix>_peaks.narrowPeak into --outdir.
rule call_peak:
    input:
        rules.filtration.output
    output:
        join(PEAK_DIR, "{sample}_peaks.narrowPeak")
    log:
        join(LOGDIR, "{sample}_macs2.log")
    params:
        genomesize = GENOMESIZE,
        prefix = "{sample}",
        outdir = PEAK_DIR
    shell:"""
	macs2 callpeak -t {input} -f BAMPE \
        -n {params.prefix} \
        -g {params.genomesize} \
        --outdir {params.outdir} 2> {log}
    """

# Convert the filtered BAM to a bigWig track for genome browsers, using
# bamCoverage with the bin size and normalization method set in params.
rule bam2bw:
    input:
        rules.filtration.output
    output:
        join(BW_DIR, "{sample}.bw")
    log:
        join(LOGDIR, "{sample}_bw.log")
    params:
        method  = "BPM",
        binsize = "10"
    threads: 10
    shell:"""
    bamCoverage -b {input} \
        --binSize {params.binsize} \
        --numberOfProcessors {threads} \
        -o {output} --normalizeUsing {params.method} 2> {log}
    """


# visualize coverage around the genes annotated in the GFF
if GFF:
    # Convert the "gene" features of the GFF into a tab-separated BED with the
    # columns: sequence, start - 1, end, gene ID, GFF column 6, GFF column 7.
    # awk regular expressions have no lazy quantifiers, so the ID is captured
    # with [^;]+ : everything after "ID=" up to the next ';' or, when there is
    # none, the end of the attribute column. The three-argument match() used
    # to fill the ID array requires gawk.
    rule gff2bed:
        input:
            gff = GFF
        output:
            bed = join(VIS_DIR, FASTA_BASE, "gene.bed")
        shell: """
        awk 'BEGIN{{OFS="\t"}}; $3 == "gene" {{match($9, /ID=([^;]+)/,ID) ; print $1,$4-1,$5,ID[1],$6,$7 }}' {input} > {output}
        """

    # Signal matrix over scaled gene bodies: every sample's bigWig against the
    # gene BED, with the body length and flank sizes given on the command line.
    rule scale_region_matrix:
        input:
            bw = ALL_BWS,
            bed = join(VIS_DIR, FASTA_BASE, "gene.bed")
        threads: 20
        output:
            join(VIS_DIR, FASTA_BASE, "matrix_scale_region.scale.gz")
        shell:"""
        computeMatrix scale-regions \
        -S {input.bw} \
        -R {input.bed} \
        --regionBodyLength 2000 \
        --beforeRegionStartLength 3000 --afterRegionStartLength 3000 \
        --skipZeros --numberOfProcessors {threads} \
        -o {output}
        """

    # Signal matrix centred on the TSS reference point: every sample's bigWig
    # against the gene BED, 3000 bp before and after the point.
    rule reference_point_matrix:
        input:
            bw = ALL_BWS,
            bed = join(VIS_DIR, FASTA_BASE, "gene.bed")
        threads: 20
        output:
            join(VIS_DIR, FASTA_BASE, "mmatrix_reference_point.reference.gz")
        shell:"""
        computeMatrix reference-point \
        -S {input.bw} \
        -R {input.bed} \
        --referencePoint TSS \
        -b 3000 -a 3000 \
        --skipZeros --numberOfProcessors {threads} \
        -o {output}
        """

    # Draw profile plots from both matrices: for each matrix one plot with
    # --perGroup and one laid out with four plots per row.
    rule matrix_vis:
        input:
            m1 = join(VIS_DIR, FASTA_BASE, "matrix_scale_region.scale.gz"),
            m2 = join(VIS_DIR, FASTA_BASE, "mmatrix_reference_point.reference.gz")
        output:
            # p1/p2 are drawn from m1, p3/p4 from m2
            p1 = join(VIS_DIR, FASTA_BASE, "scale_region.pdf"),
            p2 = join(VIS_DIR, FASTA_BASE, "scale_region_persample.pdf"),
            p3 = join(VIS_DIR, FASTA_BASE, "reference_point_region.pdf"),
            p4 = join(VIS_DIR, FASTA_BASE, "reference_point_region_persample.pdf")
        shell:"""
        plotProfile -m {input.m1} -out {output.p1} --perGroup
        plotProfile -m {input.m1} -out {output.p2} --numPlotsPerRow 4
        plotProfile -m {input.m2} -out {output.p3} --perGroup
        plotProfile -m {input.m2} -out {output.p4} --numPlotsPerRow 4
        """

## correlation analysis across multiple samples
if len(samples) > 1:
    # Summarise all bigWig tracks into one .npz file using the "bins" mode of
    # multiBigwigSummary; the correlation plots are drawn from this file.
    rule multi_bigwig_summary:
        input:
            ALL_BWS
        threads: 20
        output:
            join(VIS_DIR, FASTA_BASE, "multibw_results.npz")
        shell:"""
        multiBigwigSummary bins -b {input} \
        --numberOfProcessors {threads} \
        -o {output}
        """

    # Scatter plot of the Spearman correlation between samples, drawn from the
    # multiBigwigSummary result.
    rule correlation_scatter:
        input:
            rules.multi_bigwig_summary.output
        output:
            join(VIS_DIR, FASTA_BASE, "correlation_spearman_bwscore_scatterplot.pdf")
        shell:"""
        plotCorrelation -in {input} \
        --corMethod spearman --skipZeros \
        --whatToPlot scatterplot \
        --plotTitle "Spearman Correlation" \
        --removeOutliers \
        --plotFile {output}
        """

    # Heatmap of the Spearman correlation between samples, with the values
    # printed in the cells (--plotNumbers), drawn from the multiBigwigSummary
    # result.
    rule correlation_heatmap:
        input:
            rules.multi_bigwig_summary.output
        output:
            join(VIS_DIR, FASTA_BASE, "correlation_spearman_bwscore_heatmapplot.pdf")
        shell:"""
        plotCorrelation -in {input} \
        --corMethod spearman --skipZeros \
        --whatToPlot heatmap \
        --plotTitle "Spearman Correlation" \
        --removeOutliers \
        --plotNumbers \
        --plotFile {output}
        """

